# -*- coding: utf-8 -*-
"""
/***************************************************************************
 ExeLibSVMDialog
                                 A QGIS plugin
 Execute LibSVM Predict
                             -------------------
        begin                : 2015-06-12
        copyright            : (C) 2015 by xushiluo
        email                : xushiluo@163.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

import os
import numpy

from osgeo import gdal, ogr, osr
from osgeo.gdalconst import *

from PyQt4 import QtCore, QtGui
from ui_exelibsvm import Ui_ExeLibSVM
# Qt widget layout for this dialog (provides setupUi())

from svmDataScale import *

import sys

# Optional PyCharm remote debugger (pydevd) for attaching a debugger during
# development. The egg path is machine-specific, so a missing debugger must
# not prevent the plugin from loading.
sys.path.append(r'c:\Program Files (x86)\JetBrains\PyCharm 3.4.1\pycharm-debug.egg')
try:
    import pydevd
except ImportError:
    pydevd = None

# Make the LibSVM Python bindings in the libsvm320/python directory next to
# this file importable.
scriptPath = os.path.split(os.path.realpath(__file__))[0]
sys.path.append(os.path.join(scriptPath, "libsvm320", "python"))
from svmutil import *

# normalize ndarray
def normalize_func(minVal, maxVal, newMinValue=0, newMaxValue=1):
    """Build an element-wise min-max rescaling function.

    Values are mapped linearly so that ``minVal`` -> ``newMinValue`` and
    ``maxVal`` -> ``newMaxValue``. Values outside [minVal, maxVal] are
    extrapolated, not clipped.

    :param minVal: lower bound of the source range.
    :param maxVal: upper bound of the source range.
    :param newMinValue: value that ``minVal`` maps to (default 0).
    :param newMaxValue: value that ``maxVal`` maps to (default 1).
    :returns: a numpy ufunc (built with ``numpy.frompyfunc``) that accepts
        scalars or array-likes; array results have ``object`` dtype.

    If ``maxVal == minVal`` the source range is degenerate and every input is
    mapped to ``newMinValue`` instead of raising a division-by-zero error.
    """
    # float() forces true division even for integer bounds under Python 2.
    srcSpan = float(maxVal - minVal)
    # The target span must include newMinValue; using newMaxValue alone is
    # only correct when newMinValue == 0.
    dstSpan = float(newMaxValue - newMinValue)

    def normalizeFunc(x):
        if srcSpan == 0:
            return float(newMinValue)
        return (x - minVal) * dstSpan / srcSpan + newMinValue

    return numpy.frompyfunc(normalizeFunc, 1, 1)

class ExeLibSVMDialog(QtGui.QDialog, Ui_ExeLibSVM):
    """Dialog that runs a LibSVM train/predict cycle and saves a probability image.

    The user selects a LibSVM data file, a "projection source" GeoTIFF and an
    output GeoTIFF path. A random percentage of the valid samples is used to
    train a LibSVM model with probability estimates ('-c 4 -b 1'). Every valid
    sample is then predicted, and the highest class probability is written to
    a single-band Float32 GeoTIFF. That GeoTIFF takes its size, geotransform,
    projection and nodata value from the projection source image.
    """

    # Suffix used when the save dialog's filter string yields no usable suffix.
    _DEFAULT_OUTPUT_SUFFIX = '.tif'

    def __init__(self):
        QtGui.QDialog.__init__(self)
        self.setupUi(self)

        # Initial configuration of the user interface
        self.configUI()

        # Set up the signals
        self.connectSignals()

        # Paths chosen by the user (None until selected)
        self.inFileName = None
        self.outFileName = None
        self.projSrcFileName = None
        # Output raster size, taken from the projection source image
        self.outImgColumns = None
        self.outImgRows = None
        # Percentage of valid samples used for training; validated to be
        # in (0, 100] before training starts
        self.percentForTrainning = 10

        # self.loadTestData()

    def configUI(self):
        """Make the path fields dialog-driven and disable OK until all paths are chosen."""
        self.inputLineEdit.setReadOnly(True)
        self.outputLineEdit.setReadOnly(True)
        # Like the other two paths, the projection source is only set through
        # its file dialog, so the stored path and the field cannot diverge.
        self.projectionSrcLineEdit.setReadOnly(True)
        self.okButton.setDisabled(True)

    def connectSignals(self):
        """Wire the dialog's buttons and the percentage field to their handlers."""
        self.cancelButton.clicked.connect(self.close)
        self.inputButton.clicked.connect(self.showOpenDialog)
        self.outputButton.clicked.connect(self.showSaveDialog)
        self.okButton.clicked.connect(self.executeLibSVM)
        self.projectionSrcPushButton.clicked.connect(self.showProjSrcDlg)
        self.percentForTrainninglineEdit.textChanged.connect(self.updatePercentForTrainning)

    def _updateOkButton(self):
        """Enable OK only once the input, projection source and output paths are all set."""
        ready = bool(self.inFileName and self.projSrcFileName and self.outFileName)
        self.okButton.setEnabled(ready)

    def _showError(self, message):
        """Show a modal critical-error message box."""
        QtGui.QMessageBox.critical(self, "Critical", message)

    def showOpenDialog(self):
        """Let the user pick the LibSVM data file to classify."""
        fileName = str(QtGui.QFileDialog.getOpenFileName(self, "Input libSVM data File:", "", "libSVM data (*.svm *.dat *.csv *.txt)"))
        if not fileName:
            # Dialog cancelled: keep the previous selection
            return

        self.inFileName = fileName
        self.inputLineEdit.setText(self.inFileName)
        self._updateOkButton()

    def showProjSrcDlg(self):
        """Let the user pick the GeoTIFF whose size, georeferencing and nodata value the output copies."""
        fileName = str(QtGui.QFileDialog.getOpenFileName(self, "Projection Src Tiff Image:", "", "Tiff Image (*.tif *.tiff)"))
        if not fileName:
            # Dialog cancelled: keep the previous selection
            return

        self.projSrcFileName = fileName
        self.projectionSrcLineEdit.setText(self.projSrcFileName)
        self._updateOkButton()

    @staticmethod
    def _suffixesFromFilter(filterStr):
        """Return the suffixes listed in a Qt filter string.

        For example, 'Tiff Files (*.tif *.tiff)' -> ['.tif', '.tiff']. Returns
        [] when no '(...)' pattern list is present.
        """
        start = filterStr.find('(')
        end = filterStr.rfind(')')
        if start == -1 or end <= start:
            return []
        patterns = filterStr[start + 1:end].split()
        return [p.lstrip('*') for p in patterns if p.startswith('*.')]

    def showSaveDialog(self):
        """Let the user choose the output GeoTIFF path.

        The user's extension is kept when it matches (case-insensitively) one
        of the suffixes of the selected file-type filter. Otherwise it is
        replaced by the filter's first suffix, or added if it was missing.
        """
        # Currently the plugin only supports GeoTIFF files
        fileTypes = 'Tiff Files (*.tif *.tiff)'
        fileName, selectedFilter = QtGui.QFileDialog.getSaveFileNameAndFilter(
            self, 'Output Probability Image:', '', fileTypes)

        fileNameStr = str(fileName)
        if not fileNameStr:
            # Dialog cancelled: keep the previous selection
            return

        allowedSuffixes = self._suffixesFromFilter(str(selectedFilter))
        if not allowedSuffixes:
            allowedSuffixes = [self._DEFAULT_OUTPUT_SUFFIX]

        # splitext only inspects the last path component, so dots inside
        # directory names are not mistaken for the extension.
        baseFileName, suffix = os.path.splitext(fileNameStr)
        if suffix.lower() not in [s.lower() for s in allowedSuffixes]:
            suffix = allowedSuffixes[0]

        self.outFileName = baseFileName + suffix
        self.outputLineEdit.setText(self.outFileName)
        self._updateOkButton()

    def updatePercentForTrainning(self):
        """Parse the training-percentage field into self.percentForTrainning.

        Returns True on success. For a non-numeric entry, shows an error box,
        keeps the previous value and returns False.
        """
        try:
            self.percentForTrainning = float(self.percentForTrainninglineEdit.text())
        except ValueError:
            self._showError("percent input error!")
            return False
        return True

    def updateInFileName(self):
        """Refresh self.inFileName from the input path field."""
        self.inFileName = self.inputLineEdit.text()

    def updateOutFileName(self):
        """Refresh self.outFileName from the output path field."""
        self.outFileName = self.outputLineEdit.text()

    def updateProjSrcFileName(self):
        """Refresh self.projSrcFileName from the projection source path field."""
        self.projSrcFileName = self.projectionSrcLineEdit.text()

    @staticmethod
    def _sampleTrainingIndices(totalSize, percent):
        """Return distinct random indices into range(totalSize).

        The count is `percent`% of totalSize, with a minimum of one.
        """
        trainSize = max(1, int(totalSize * percent / 100.0))
        # Draw from the whole sample range without replacement, not just from
        # the first trainSize rows.
        return numpy.random.permutation(totalSize)[:trainSize]

    def _writeProbabilityImage(self, srcDataset, outArray, nodataValue):
        """Write outArray as a single-band Float32 GeoTIFF at self.outFileName.

        The geotransform and projection are copied from srcDataset. Returns
        False after showing an error if the output file cannot be created.
        """
        driver = gdal.GetDriverByName("GTiff")
        rows, cols = outArray.shape
        outDataset = driver.Create(self.outFileName, cols, rows, 1, gdal.GDT_Float32)
        if outDataset is None:
            self._showError("cannot create output image!")
            return False

        pBand = outDataset.GetRasterBand(1)
        # GetNoDataValue() returns None when the source band defines no
        # nodata value; SetNoDataValue(None) would raise.
        if nodataValue is not None:
            pBand.SetNoDataValue(nodataValue)
        pBand.WriteArray(outArray)

        geoTransform = srcDataset.GetGeoTransform()
        if geoTransform is not None:
            outDataset.SetGeoTransform(geoTransform)
        projection = srcDataset.GetProjection()
        if projection:
            outDataset.SetProjection(projection)

        # Dropping the references lets GDAL flush and close the file.
        pBand = None
        outDataset = None
        return True

    def executeLibSVM(self):
        """Train, predict and write the probability image, then close the dialog.

        Invalid settings are reported in a message box and leave the dialog
        open. These include a bad percentage, missing paths, an unreadable
        projection source, or a sample count that does not match the raster
        size.
        """
        # pydevd.settrace('localhost',  port=53100,  stdoutToServer=True,  stderrToServer=True)
        self.updateInFileName()
        self.updateOutFileName()
        self.updateProjSrcFileName()
        if not self.updatePercentForTrainning():
            return
        if not 0 < self.percentForTrainning <= 100:
            self._showError("percent must be greater than 0 and at most 100!")
            return
        if not (self.inFileName and self.projSrcFileName and self.outFileName):
            self._showError("input, projection source and output files are required!")
            return

        gdal.AllRegister()
        # Open and assign the contents of the raster file to a dataset
        dataset = gdal.Open(self.projSrcFileName, GA_ReadOnly)
        if dataset is None:
            self._showError("cannot open projection source image!")
            return
        band0_NodataValue = dataset.GetRasterBand(1).GetNoDataValue()
        self.outImgColumns = dataset.RasterXSize
        self.outImgRows = dataset.RasterYSize

        # Store labels as float64 so the same array can hold probabilities.
        y, _ = svm_read_problem(self.inFileName)
        yArray = numpy.array(y, dtype=numpy.float64)
        # The result is reshaped to the raster size below, so the data file
        # must contain exactly one sample per pixel.
        if yArray.size != self.outImgRows * self.outImgColumns:
            self._showError("sample count does not match projection source image size!")
            return

        # NOTE(review): svmDataScale is assumed to return the scaled feature
        # vectors of the valid (non-nodata) samples plus a boolean mask that
        # selects those samples in yArray -- confirm against svmDataScale.py.
        dataXArray, dataIndexBoolArray = svmDataScale(self.inFileName, None, band0_NodataValue)
        dataYArray = yArray[dataIndexBoolArray]

        totalSize = dataYArray.size
        if totalSize == 0:
            self._showError("no valid samples found in input file!")
            return

        indexArray = self._sampleTrainingIndices(totalSize, self.percentForTrainning)
        trainY = [dataYArray[i] for i in indexArray]
        trainX = [dataXArray[i] for i in indexArray]

        model = svm_train(trainY, trainX, '-c 4 -b 1')
        _, _, p_val = svm_predict(dataYArray, dataXArray, model, '-b 1')

        # With '-b 1' svm_predict returns per-class probability estimates;
        # keep the probability of the most likely class for each sample.
        # Samples that are not selected by the mask keep their original label
        # value in the output.
        yArray[dataIndexBoolArray] = numpy.array([max(probs) for probs in p_val])

        # Assumes samples are stored in raster row-major order -- TODO confirm.
        outArray = numpy.reshape(yArray, (self.outImgRows, self.outImgColumns))
        if not self._writeProbabilityImage(dataset, outArray, band0_NodataValue):
            return
        dataset = None

        self.close()

    def loadTestData(self):
        """Pre-fill the dialog with the author's local test files (debug aid, not called by default)."""
        self.inputLineEdit.setText(r"h:\0RSIA_SVM\result\label_Features.svm")
        self.percentForTrainninglineEdit.setText("2")
        self.projectionSrcLineEdit.setText(r"h:\0RSIA_SVM\result\Label_Features.tif")
        self.outputLineEdit.setText(r"h:\0RSIA_SVM\result\11.tif")
        # Keep the stored paths in sync with the fields just filled in.
        self.updateInFileName()
        self.updateProjSrcFileName()
        self.updateOutFileName()
        self._updateOkButton()